{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import re\n",
    "import collections\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_textcleaning(string):\n",
    "    string = re.sub('[^A-Za-z ]+', ' ', string)\n",
    "    return re.sub(r'[ ]+', ' ', string.lower()).strip()\n",
    "\n",
    "def batch_sequence(sentences, dictionary, maxlen = 50):\n",
    "    np_array = np.zeros((len(sentences), maxlen), dtype = np.int32)\n",
    "    for no_sentence, sentence in enumerate(sentences):\n",
    "        current_no = 0\n",
    "        for no, word in enumerate(sentence.split()[: maxlen - 2]):\n",
    "            np_array[no_sentence, no] = dictionary.get(word, 1)\n",
    "            current_no = no\n",
    "        np_array[no_sentence, current_no + 1] = 3\n",
    "    return np_array\n",
    "\n",
    "def counter_words(sentences):\n",
    "    word_counter = collections.Counter()\n",
    "    word_list = []\n",
    "    num_lines, num_words = (0, 0)\n",
    "    for i in sentences:\n",
    "        words = re.findall('[\\\\w\\']+|[;:\\-\\(\\)&.,!?\"]', i)\n",
    "        word_counter.update(words)\n",
    "        word_list.extend(words)\n",
    "        num_lines += 1\n",
    "        num_words += len(words)\n",
    "    return word_counter, word_list, num_lines, num_words\n",
    "\n",
    "\n",
    "def build_dict(word_counter, vocab_size = 50000):\n",
    "    count = [['PAD', 0], ['UNK', 1], ['START', 2], ['END', 3]]\n",
    "    count.extend(word_counter.most_common(vocab_size))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    return dictionary, {word: idx for idx, word in dictionary.items()}\n",
    "\n",
    "def split_by_dot(string):\n",
    "    string = re.sub(\n",
    "        r'(?<!\\d)\\.(?!\\d)',\n",
    "        'SPLITTT',\n",
    "        string.replace('\\n', '').replace('/', ' '),\n",
    "    )\n",
    "    string = string.split('SPLITTT')\n",
    "    return [re.sub(r'[ ]+', ' ', sentence).strip() for sentence in string]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9923"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "contents = []\n",
    "with open('books/Blood_Born') as fopen:\n",
    "    contents.extend(split_by_dot(fopen.read()))\n",
    "    \n",
    "with open('books/Dark_Thirst') as fopen:\n",
    "    contents.extend(split_by_dot(fopen.read()))\n",
    "    \n",
    "len(contents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8390"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "contents = [simple_textcleaning(sentence) for sentence in contents]\n",
    "contents = [sentence for sentence in contents if len(sentence) > 20]\n",
    "len(contents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9039"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "maxlen = 50\n",
    "vocabulary_size = len(set(' '.join(contents).split()))\n",
    "embedding_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 16\n",
    "vocabulary_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils import shuffle\n",
    "\n",
    "stride = 1\n",
    "t_range = int((len(contents) - 3) / stride + 1)\n",
    "left, middle, right = [], [], []\n",
    "for i in range(t_range):\n",
    "    slices = contents[i * stride : i * stride + 3]\n",
    "    left.append(slices[0])\n",
    "    middle.append(slices[1])\n",
    "    right.append(slices[2])\n",
    "\n",
    "left, middle, right = shuffle(left, middle, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_counter, _, _, _ = counter_words(middle)\n",
    "dictionary, _ = build_dict(word_counter, vocab_size = vocabulary_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention:\n",
    "    def __init__(self,hidden_size):\n",
    "        self.hidden_size = hidden_size\n",
    "        self.dense_layer = tf.layers.Dense(hidden_size)\n",
    "        self.v = tf.random_normal([hidden_size],mean=0,stddev=1/np.sqrt(hidden_size))\n",
    "        \n",
    "    def score(self, hidden_tensor, encoder_outputs):\n",
    "        energy = tf.nn.tanh(self.dense_layer(tf.concat([hidden_tensor,encoder_outputs],2)))\n",
    "        energy = tf.transpose(energy,[0,2,1])\n",
    "        batch_size = tf.shape(encoder_outputs)[0]\n",
    "        v = tf.expand_dims(tf.tile(tf.expand_dims(self.v,0),[batch_size,1]),1)\n",
    "        energy = tf.matmul(v,energy)\n",
    "        return tf.squeeze(energy,1)\n",
    "    \n",
    "    def __call__(self, hidden, encoder_outputs):\n",
    "        seq_len = tf.shape(encoder_outputs)[1]\n",
    "        batch_size = tf.shape(encoder_outputs)[0]\n",
    "        H = tf.tile(tf.expand_dims(hidden, 1),[1,seq_len,1])\n",
    "        attn_energies = self.score(H,encoder_outputs)\n",
    "        return tf.expand_dims(tf.nn.softmax(attn_energies),1)\n",
    "\n",
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dict_size,\n",
    "        size_layers,\n",
    "        learning_rate,\n",
    "        maxlen,\n",
    "        num_blocks = 3,\n",
    "    ):\n",
    "        block_size = size_layers\n",
    "        self.BEFORE = tf.placeholder(tf.int32,[None,maxlen])\n",
    "        self.INPUT = tf.placeholder(tf.int32,[None,maxlen])\n",
    "        self.AFTER = tf.placeholder(tf.int32,[None,maxlen])\n",
    "        self.batch_size = tf.shape(self.INPUT)[0]\n",
    "        self.output_layer = tf.layers.Dense(dict_size, name=\"output_layer\")\n",
    "        self.output_layer.build(size_layers)\n",
    "        self.embeddings = tf.Variable(tf.random_uniform([dict_size, size_layers], -1, 1))\n",
    "        embedded = tf.nn.embedding_lookup(self.embeddings, self.INPUT)\n",
    "        self.attention = Attention(size_layers)\n",
    "\n",
    "        def residual_block(x, size, rate, block, reuse = False):\n",
    "            with tf.variable_scope(\n",
    "                'block_%d_%d' % (block, rate), reuse = reuse\n",
    "            ):\n",
    "                attn_weights = self.attention(tf.reduce_sum(x,axis=1), x)\n",
    "                conv_filter = tf.layers.conv1d(\n",
    "                    attn_weights,\n",
    "                    x.shape[2] // 4,\n",
    "                    kernel_size = size,\n",
    "                    strides = 1,\n",
    "                    padding = 'same',\n",
    "                    dilation_rate = rate,\n",
    "                    activation = tf.nn.tanh,\n",
    "                )\n",
    "                conv_gate = tf.layers.conv1d(\n",
    "                    x,\n",
    "                    x.shape[2] // 4,\n",
    "                    kernel_size = size,\n",
    "                    strides = 1,\n",
    "                    padding = 'same',\n",
    "                    dilation_rate = rate,\n",
    "                    activation = tf.nn.sigmoid,\n",
    "                )\n",
    "                out = tf.multiply(conv_filter, conv_gate)\n",
    "                out = tf.layers.conv1d(\n",
    "                    out,\n",
    "                    block_size,\n",
    "                    kernel_size = 1,\n",
    "                    strides = 1,\n",
    "                    padding = 'same',\n",
    "                    activation = tf.nn.tanh,\n",
    "                )\n",
    "                return tf.add(x, out), out\n",
    "\n",
    "        forward = tf.layers.conv1d(\n",
    "            embedded, block_size, kernel_size = 1, strides = 1, padding = 'SAME'\n",
    "        )\n",
    "        zeros = tf.zeros_like(forward)\n",
    "        for i in range(num_blocks):\n",
    "            for r in [1, 2, 4, 8, 16]:\n",
    "                forward, s = residual_block(\n",
    "                    forward, size = 7, rate = r, block = i\n",
    "                )\n",
    "                zeros = tf.add(zeros, s)\n",
    "        forward = tf.layers.conv1d(\n",
    "            zeros,\n",
    "            block_size,\n",
    "            kernel_size = 1,\n",
    "            strides = 1,\n",
    "            padding = 'SAME',\n",
    "            activation = tf.nn.tanh,\n",
    "        )\n",
    "        self.get_thought = tf.reduce_sum(forward,axis=1, name = 'logits')\n",
    "        \n",
    "        def decoder(labels, reuse):\n",
    "            decoder_in = tf.nn.embedding_lookup(self.embeddings, labels)\n",
    "            forward = tf.layers.conv1d(\n",
    "                decoder_in, block_size, kernel_size = 1, strides = 1, padding = 'SAME'\n",
    "            )\n",
    "            zeros = tf.zeros_like(forward)\n",
    "            for r in [8, 16, 24]:\n",
    "                forward, s = residual_block(forward, size = 7, rate = r, block = 10, reuse = reuse)\n",
    "                zeros = tf.add(zeros, s)\n",
    "            return tf.layers.conv1d(\n",
    "                zeros,\n",
    "                block_size,\n",
    "                kernel_size = 1,\n",
    "                strides = 1,\n",
    "                padding = 'SAME',\n",
    "                activation = tf.nn.tanh,\n",
    "            )\n",
    "        \n",
    "        fw_logits = decoder(self.AFTER, False)\n",
    "        bw_logits = decoder(self.BEFORE, True)\n",
    "        self.attention = tf.matmul(\n",
    "            self.get_thought, tf.transpose(self.embeddings), name = 'attention'\n",
    "        )\n",
    "        self.loss = self.calculate_loss(fw_logits, self.AFTER) + self.calculate_loss(bw_logits, self.BEFORE)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)\n",
    "    \n",
    "    def calculate_loss(self, outputs, labels):\n",
    "        mask = tf.cast(tf.sign(labels), tf.float32)\n",
    "        logits = self.output_layer(outputs)\n",
    "        return tf.contrib.seq2seq.sequence_loss(logits, labels, mask)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(len(dictionary), embedding_size, learning_rate, maxlen)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 525/525 [00:23<00:00, 21.91it/s, cost=8.56]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:23<00:00, 23.81it/s, cost=5.04]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:23<00:00, 23.78it/s, cost=3.43]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:23<00:00, 23.77it/s, cost=2.48]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:23<00:00, 23.71it/s, cost=1.92]\n"
     ]
    }
   ],
   "source": [
    "for i in range(5):\n",
    "    pbar = tqdm(range(0, len(middle), batch_size), desc='train minibatch loop')\n",
    "    for p in pbar:\n",
    "        index = min(p + batch_size, len(middle))\n",
    "        batch_x = batch_sequence(\n",
    "                middle[p : index],\n",
    "                dictionary,\n",
    "                maxlen = maxlen,\n",
    "        )\n",
    "        batch_y_before = batch_sequence(\n",
    "                left[p : index],\n",
    "                dictionary,\n",
    "                maxlen = maxlen,\n",
    "        )\n",
    "        batch_y_after = batch_sequence(\n",
    "                right[p : index],\n",
    "                dictionary,\n",
    "                maxlen = maxlen,\n",
    "        )\n",
    "        loss, _ = sess.run([model.loss, model.optimizer], \n",
    "                           feed_dict = {model.BEFORE: batch_y_before,\n",
    "                                        model.INPUT: batch_x,\n",
    "                                        model.AFTER: batch_y_after,})\n",
    "        pbar.set_postfix(cost=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('books/Driftas_Quest') as f:\n",
    "    book = f.read()\n",
    "\n",
    "book = split_by_dot(book)\n",
    "book = [simple_textcleaning(sentence) for sentence in book]\n",
    "book = [sentence for sentence in book if len(sentence) > 20][100:200]\n",
    "book_sequences = batch_sequence(book, dictionary, maxlen = maxlen)\n",
    "encoded, attention = sess.run([model.get_thought, model.attention],feed_dict={model.INPUT:book_sequences})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grofinglaz was a dead weight but drifta heaved him over his shoulders and staggered with him to the hova gently laying the dazed miner across the seats. you say that now but one day you ll want to find that special someone to settle down with and a regular income would be good. as grofinglaz recovered in the sickbay the crew laboured on drifta teaming up with a pair of miners and worked harder than ever to earn creds not for himself but the injured grofinglaz. when it was time to leave mars drifta had gained the captain s permission to once more sit with him and the pilot on the flight deck to take a last look at earth. but there was a discrepancy. sometime in the night someone had breached port security and placed a bio crib amongst the packages. this didn t go unnoticed amongst the crew. keep one eye on the seam the other on the roof. kargondov had been the medical officer and as most medical officers had other duties. as a stack was brought into one of the holds the precise weight was verified and a significant error from the inventory log appeared\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import pairwise_distances_argmin_min\n",
    "\n",
    "n_clusters = 10\n",
    "kmeans = KMeans(n_clusters=n_clusters, random_state=0)\n",
    "kmeans = kmeans.fit(encoded)\n",
    "avg = []\n",
    "closest = []\n",
    "for j in range(n_clusters):\n",
    "    idx = np.where(kmeans.labels_ == j)[0]\n",
    "    avg.append(np.mean(idx))\n",
    "closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_,encoded)\n",
    "ordering = sorted(range(n_clusters), key=lambda k: avg[k])\n",
    "print('. '.join([book[closest[idx]] for idx in ordering]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Important words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['fleeting',\n",
       " 'parked',\n",
       " 'expectantly',\n",
       " 'own',\n",
       " 'emotionless',\n",
       " 'pizzeria',\n",
       " 'veiled',\n",
       " 'nearly',\n",
       " 'tradition',\n",
       " 'frenzied']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "indices = np.argsort(attention.mean(axis=0))[::-1]\n",
    "rev_dictionary = {v:k for k, v in dictionary.items()}\n",
    "[rev_dictionary[i] for i in indices[:10]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
